WIP [ilu/ttx] optimize quant_moe_export#343
Conversation
Claude Code ReviewVerdict: Comment -- Switch from rank-1 int32 accum to fp16 tl.dot plus shared group-offset reuse looks reasonable, but there are correctness/perf concerns around BLOCK_K vs QUANT_GROUP_SIZE and the fc1 dtype change. SummaryThe PR replaces the per-K rank-1 int8 accumulation with a tiled fp16 Must fix
SuggestionsSuggestions (3)
Notes
|
There was a problem hiding this comment.
Code Review
This pull request refactors the Triton int8 grouped matmul kernel to use matrix engine-based tl.dot by losslessly casting int8 operands to fp16, avoiding previous compiler bugs. It also introduces a performance optimization that computes group offsets and maximum token counts once and shares them across all kernel launches to minimize device-to-host synchronization overhead. Additionally, the intermediate fc1_out tensor is changed from float32 to the input precision to reduce memory bandwidth. Feedback suggests using torch.zeros instead of torch.empty followed by a manual zero-write to avoid an unnecessary host-to-device scalar write.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
0dddddc to
4204d31
Compare
Claude Code ReviewVerdict: Request changes -- Int8->fp16 dot path can lose precision when BLOCK_K exceeds the quant group size, and the autotuner key does not include QUANT_GROUP_SIZE so the prune hook may be ineffective. SummaryReplaces the rank-1 int32 accumulation in the grouped int8 MoE GEMM with an fp16 Must fix
SuggestionsSuggestions (3)
Notes
|
4204d31 to
b5946fa
Compare
Claude Code ReviewVerdict: Comment -- Switches int8 grouped GEMM from rank-1 int32 accumulation to fp16 tl.dot with int32 partials, and threads shared group offsets through the pipeline; correctness hinges on a couple of fp16/tile-size assumptions worth a closer look. SummaryThe PR replaces the inner per-K rank-1 int32 accumulation in Must fix
SuggestionsSuggestions (3)
Notes
|
b5946fa to
df91212
Compare
Claude Code ReviewVerdict: Comment -- Switch from scalar int32 outer-product to fp16 tl.dot with int32 round-trip is plausible but has a few correctness/perf concerns worth addressing. SummaryReplaces the rank-1 int32 accumulator loop in the grouped int8 MoE GEMM with a tiled Must fix
SuggestionsSuggestions (4)
Notes
|
|
生产里 group size 只有 {128, 256, 320, 512} 和 K(128 的倍数),唯一不能被 BLOCK_K 整除的是 320 与 BK=128。 |
Rank-1 accumulation → matrix-engine tl.dot The int8 GEMM previously fell back to a rank-1 outer-product accumulation because ILU's int8 tl.dot miscompiles (invalid bitcast → segfault). Now the int8 operands (|v| ≤ 127) are losslessly cast to fp16 and fed to an fp16 MMA; each BLOCK_K tile is computed exactly in the fp32 dot output, cast back to int32 and accumulated, then dequantized per group. This uses the matrix engine while staying bit-exact (autotune configs gained a BLOCK_K dim). ~6x speedup.
Deduplicated host→device syncs The four sub-steps (input quant / up-GEMM-swiglu / requant / down-GEMM) each recomputed group_offsets and called .max().item() — 4 device→host syncs that serialized the launches. Now _make_group_offsets() computes them once at the entry and threads them through all steps, leaving a single sync. ~28% faster.
Intermediate fc1_out fp32 → bf16 The SwiGLU output is stored as bf16 (matching ixformer's up-GEMM output) instead of fp32, halving the intermediate round-trip traffic. Accuracy-neutral; perf-neutral at the benchmarked (launch-bound) sizes but beneficial at larger batches and lower memory footprint.
Cumulative (24 experts, top_k=4, hidden=512, inter=1280, 97 tokens): int8 ~9.6 → 1.10 ms (~8.7x), int4 ~10.4 → 1.13 ms (~9.2x). All verified bit-accurate for both int8 and int4.